{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Paged Attention in cuDNN Frontend\n",
    "\n",
    "This notebook illustrates how the cuDNN frontend's scaled dot product attention (SDPA) operator can be used to support paged attention. For a simpler introduction to the scaled dot product attention operator, please refer to [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb).\n",
    "\n",
    "The full documentation of cuDNN's scaled dot product attention operator can be found in [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention). The Python test code for the full set of features can be found in [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py).\n",
    "\n",
    "More details on paged attention can be found in the [PagedAttention paper](https://arxiv.org/abs/2309.06180)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/52_scaled_dot_product_attention.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prerequisites and Setup\n",
    "This notebook requires an NVIDIA GPU A100 or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('pip install nvidia-cudnn-frontend')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "import math\n",
    "\n",
    "torch.manual_seed(42)\n",
    "handle = cudnn.create_handle()\n",
    "\n",
    "assert torch.cuda.is_available()\n",
    "assert (\n",
    "    torch.cuda.get_device_capability()[0] >= 8\n",
    "), \"SDPA operation is only supported on SM80 architecture (Ampere) or above\"\n",
    "\n",
    "assert (\n",
    "    cudnn.backend_version() >= 90500\n",
    "), \"SDPA operation is only supported cuDNN version 9.5.0 or above\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Problem sizes and tensor setup\n",
    "\n",
    "For this example, we will use the same problem size as in  [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb).\n",
    "In addition we are setting the block_size for both K and V to 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the query, key, value, and output GPU tensors using PyTorch. However, the user may use any DLPack compatible tensor instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = 4  # batch size\n",
    "h = 12  # query number of heads\n",
    "s = 1024  # maximum sequence length\n",
    "d = 64  # embedding dimension per head\n",
    "\n",
    "block_size_k = block_size_v = (\n",
    "    64  # block size to be used by the non contiguous K/V containers\n",
    ")\n",
    "\n",
    "attn_scale = 1.0 / math.sqrt(d)\n",
    "\n",
    "# The tensors will have non-interleaved\n",
    "# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout\n",
    "# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout\n",
    "dims = (b, h, s, d)\n",
    "strides = (s * h * d, d, h * d, 1)\n",
    "\n",
    "q_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "k_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "v_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "o_gpu = torch.empty(b * s * h * d).half().cuda().as_strided(dims, strides)"
   ]
  },
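  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the stride tuple above more concrete, the following CPU-only sketch computes linear element offsets for a few logical `(b, h, s, d)` indices. The helper `linear_offset` is a hypothetical name introduced here for illustration; it is not part of cuDNN or PyTorch.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper: linear element offset of a logical index under given strides\n",
    "def linear_offset(index, strides):\n",
    "    return sum(i * st for i, st in zip(index, strides))\n",
    "\n",
    "\n",
    "b, h, s, d = 4, 12, 1024, 64\n",
    "# Strides listed in logical (b, h, s, d) order, describing BSHD memory order\n",
    "strides = (s * h * d, d, h * d, 1)\n",
    "\n",
    "head_step = linear_offset((0, 1, 0, 0), strides)  # next head, same token\n",
    "token_step = linear_offset((0, 0, 1, 0), strides)  # next token, same head\n",
    "batch_step = linear_offset((1, 0, 0, 0), strides)  # next batch entry\n",
    "print(head_step, token_step, batch_step)  # 64 768 786432\n",
    "```\n",
    "\n",
    "Stepping to the next head moves by `d` elements, while stepping to the next token moves by `h * d` elements, which is what a BSHD physical layout implies."
   ]
  },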
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create variable sequence length tensors. These are required when using paged K/V caches. To keep things simple, we set these to the maximum sequence length `s` in this example."
   ]
  },
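  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set to s for all batches, just for the notebook sample.\n",
    "# The dtype is set explicitly to int32, matching the page tables\n",
    "# (torch.full would otherwise infer int64 from the Python int s).\n",
    "seq_len_q_gpu = torch.full((b, 1, 1, 1), s, dtype=torch.int32, device=\"cuda\")\n",
    "seq_len_kv_gpu = torch.full((b, 1, 1, 1), s, dtype=torch.int32, device=\"cuda\")"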
   ]
  },
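  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In general, sequences in a batch have different lengths, and a paged cache then only needs enough blocks to cover each sequence. The following sketch illustrates that arithmetic for some hypothetical lengths (the values in `seq_lens_kv` are made up for illustration; this notebook uses `s` for every sequence).\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "block_size = 64\n",
    "# Hypothetical per-sequence K/V lengths, chosen only for illustration\n",
    "seq_lens_kv = [1024, 300, 65, 1]\n",
    "\n",
    "# Number of blocks needed to hold each sequence\n",
    "blocks_in_use = [math.ceil(n / block_size) for n in seq_lens_kv]\n",
    "# Number of valid tokens in the last block of each sequence\n",
    "valid_in_last_block = [(n - 1) % block_size + 1 for n in seq_lens_kv]\n",
    "\n",
    "print(blocks_in_use)  # [16, 5, 2, 1]\n",
    "print(valid_in_last_block)  # [64, 44, 1, 1]\n",
    "```\n",
    "\n",
    "Only the last block of a sequence can be partially filled, which is why the actual sequence lengths have to be passed alongside the page tables."
   ]
  },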
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate containers and page tables for K and V\n",
    "\n",
    "In a real-world scenario, the container and page table tensors are generated by other parts of the model. For illustration purposes, this example provides a helper function that generates a trivial container from contiguous K and V caches.\n",
    "The helper function takes a cache (e.g., the K cache), splits its sequence (`S`) dimension into blocks of length `block_size`, and concatenates the resulting blocks along the first dimension of the container. The page table then records, for each sequence in the batch, which blocks of the container belong to it and in which order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function to create a non contiguous container in blocks of block_size from a contiguous tensor\n",
    "def create_container_and_page_table(tensor, block_size):\n",
    "    B, H, S, D = tensor.shape\n",
    "    blocks_per_batch = math.ceil(S / block_size)\n",
    "\n",
    "    # This assertion keeps the helper function of this example simple, but is not a requirement for paged attention.\n",
    "    assert (blocks_per_batch * block_size) == S\n",
    "\n",
    "    # Create a container by splitting on the S dimension and concatenating at the block dimension\n",
    "    # Its dimensions are [num_blocks, H, block_size, D] with num_blocks = B * blocks_per_batch\n",
    "    container = torch.cat((tensor.clone()).chunk(blocks_per_batch, dim=2), dim=0)\n",
    "\n",
    "    # Create the page table\n",
    "    page_table = torch.linspace(\n",
    "        0,\n",
    "        B * blocks_per_batch - 1,\n",
    "        B * blocks_per_batch,\n",
    "        device=\"cuda\",\n",
    "        dtype=torch.int32,\n",
    "    ).reshape(blocks_per_batch, 1, B, 1)\n",
    "    page_table = torch.transpose(page_table, 0, 2)\n",
    "\n",
    "    return (container, page_table)\n",
    "\n",
    "\n",
    "# Create non contiguous containers with page tables for K and V from the contiguous k_gpu and v_gpu\n",
    "container_k_gpu, page_table_k_gpu = create_container_and_page_table(k_gpu, block_size_k)\n",
    "container_v_gpu, page_table_v_gpu = create_container_and_page_table(v_gpu, block_size_v)"
   ]
  },
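  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following CPU-only sketch mirrors the trivial page table built by the helper above, using plain Python lists, and shows how a logical token position is resolved to a block and an offset within that block. The function `locate` is a hypothetical name used only for illustration; cuDNN performs the equivalent lookup internally.\n",
    "\n",
    "```python\n",
    "B = 4  # batch size, as in this notebook\n",
    "S, block_size = 1024, 64\n",
    "blocks_per_batch = S // block_size\n",
    "\n",
    "# Trivial layout of the helper: block c of sequence b sits at container index c * B + b\n",
    "page_table = [[c * B + b for c in range(blocks_per_batch)] for b in range(B)]\n",
    "\n",
    "\n",
    "# Hypothetical helper: find logical token t of sequence b inside the container\n",
    "def locate(b, t):\n",
    "    block_id = page_table[b][t // block_size]\n",
    "    return block_id, t % block_size\n",
    "\n",
    "\n",
    "print(locate(0, 0))  # (0, 0)\n",
    "print(locate(1, 0))  # (1, 0)\n",
    "print(locate(1, 64))  # (5, 0)\n",
    "print(locate(3, 1023))  # (63, 63)\n",
    "```\n",
    "\n",
    "In a real deployment the block indices in the page table are typically arbitrary rather than following a regular pattern, but the lookup stays the same."
   ]
  },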
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Graph creation and execution\n",
    "\n",
    "Create the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = cudnn.pygraph(\n",
    "    io_data_type=cudnn.data_type.HALF,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "q = graph.tensor_like(q_gpu)\n",
    "\n",
    "container_k = graph.tensor_like(container_k_gpu)\n",
    "container_v = graph.tensor_like(container_v_gpu)\n",
    "page_table_k = graph.tensor_like(page_table_k_gpu)\n",
    "page_table_v = graph.tensor_like(page_table_v_gpu)\n",
    "\n",
    "seq_len_q = graph.tensor_like(seq_len_q_gpu)\n",
    "seq_len_kv = graph.tensor_like(seq_len_kv_gpu)\n",
    "\n",
    "o, _ = graph.sdpa(\n",
    "    name=\"sdpa\",\n",
    "    q=q,\n",
    "    k=container_k,  # Container K: non contiguous container with K blocks\n",
    "    v=container_v,  # Container V: non contiguous container with V blocks\n",
    "    is_inference=True,\n",
    "    attn_scale=attn_scale,\n",
    "    use_causal_mask=True,\n",
    "    use_padding_mask=True,\n",
    "    seq_len_q=seq_len_q,\n",
    "    seq_len_kv=seq_len_kv,\n",
    "    paged_attention_k_table=page_table_k,  # Page Table K: Tensor containing offsets to the container with K blocks\n",
    "    paged_attention_v_table=page_table_v,  # Page Table V: Tensor containing offsets to the container with V blocks\n",
    "    paged_attention_max_seq_len_kv=s,  # The maximum sequence length for K caches (this is optional, but recommended)\n",
    ")\n",
    "\n",
    "o.set_output(True).set_dim(dims).set_stride(strides)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.validate()\n",
    "graph.build_operation_graph()\n",
    "graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "graph.check_support()\n",
    "graph.build_plans()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_pack = {\n",
    "    q: q_gpu,\n",
    "    container_k: container_k_gpu,\n",
    "    container_v: container_v_gpu,\n",
    "    page_table_k: page_table_k_gpu,\n",
    "    page_table_v: page_table_v_gpu,\n",
    "    seq_len_q: seq_len_q_gpu,\n",
    "    seq_len_kv: seq_len_kv_gpu,\n",
    "    o: o_gpu,\n",
    "}\n",
    "\n",
    "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n",
    "graph.execute(variant_pack, workspace)\n",
    "torch.cuda.synchronize()\n",
    "\n",
    "cudnn.destroy_handle(handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run the PyTorch reference and compare against cuDNN's output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_ref = q_gpu.detach().float().requires_grad_()\n",
    "k_ref = k_gpu.detach().float().requires_grad_()\n",
    "v_ref = v_gpu.detach().float().requires_grad_()\n",
    "\n",
    "o_ref = torch.nn.functional.scaled_dot_product_attention(\n",
    "    q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale\n",
    ")\n",
    "\n",
    "torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
